{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor能够广播需要满足的规则"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2, 3])\n"
     ]
    }
   ],
   "source": [
    "#所有维度的Size都相等，不需要广播即可以参与正常的运算\n",
    "x = torch.ones(1,2,3)\n",
    "y = torch.zeros(1,2,3)\n",
    "print((x+y).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[5.],\n",
      "        [7.]])\n"
     ]
    }
   ],
   "source": [
    "#至少有一个维度\n",
    "x = torch.Tensor([2])\n",
    "y = torch.Tensor([[3],[5]])\n",
    "#x将自动扩展为[[2],[2]],然后参与运算\n",
    "print(x+y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 4, 5, 6, 7])\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (6) must match the size of tensor b (2) at non-singleton dimension 4",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-15-1e71fe357857>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[0ma\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m6\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m \u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m    \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m+\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m: The size of tensor a (6) must match the size of tensor b (2) at non-singleton dimension 4"
     ]
    }
   ],
   "source": [
    "#从最里面维度开始遍历，要么维度Size相等，要么其中一个Size为1，要么其中一个维度缺失。三者满足其一便可自动广播\n",
    "#1楼：7==7相等\n",
    "#2楼：其中一个size==1\n",
    "#3楼：其中一个size==1\n",
    "#4楼：4==4相等\n",
    "#5楼：3==3相等\n",
    "#6楼：其中一个tensor维度缺失\n",
    "x = torch.randn(2,3,4,5,6,7)\n",
    "y =    torch.rand(3,4,1,1,7)\n",
    "print((x+y).shape)\n",
    "#下面的Tensor不能广播，因此不能正常参与运算\n",
    "a = torch.randn(2,3,4,5,6,7)\n",
    "b =    torch.rand(3,4,1,2,7)\n",
    "print((a+b))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 两个能够自动广播的Tensor计算结果各个维度Size计算规则"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#如果两个Tensor的维度不相等，将自动在维度低的Tensor外包size为1的维度，直到和高的Tensor的维度相等\n",
    "#对于每个维度的Size，取两个Tensor的最大值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3])\n",
      "torch.Size([1, 4, 2, 3])\n",
      "torch.Size([1, 4, 2, 3])\n"
     ]
    }
   ],
   "source": [
    "#2*3\n",
    "x = torch.randn(2,3)\n",
    "print(x.shape)\n",
    "#1*4*2*3\n",
    "y = torch.rand(1,4,2,3)\n",
    "print(y.shape)\n",
    "#x维度为2，y维度为4，因此会在x外增加size为1的维度变成[[[[1,2,3],[1,2,3]]]]\n",
    "#维度相等后对应维度按照‘广播规则’进行广播，是的size相等。\n",
    "#输出结果：维度等于高维度tensor的维度，每个维度size取两个tensor对应维度的最大值\n",
    "print((x+y).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1, 2, 3])\n"
     ]
    }
   ],
   "source": [
    "print(torch.Tensor([[[[1,2,3],[1,2,3]]]]).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
